#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Emitter for Sm100 Epilogue Visitor
"""

from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter


class Sm100CollectiveEpilogue:
    """
    Emits the C++ ``Sm100EpilogueDescriptor`` alias for an epilogue visitor graph.

    The instance collects the tile shape, element types, epilogue schedule and
    opcode class at construction time; :meth:`emit` renders them, together with
    the fusion-callback declaration, into C++ source text.
    """
    def __init__(self, tile_description,
                 kernel_schedule,
                 epilogue_schedule,
                 element_accumulator,
                 element_d,
                 fusion_callbacks) -> None:
        """
        :param tile_description: tile description; its ``cluster_shape`` and
            ``math_instruction.opcode_class`` attributes are read here
        :param kernel_schedule: forwarded to ``to_blackwell_threadblock_shape``
        :param epilogue_schedule: key later looked up in ``EpilogueScheduleTag``
        :param element_accumulator: key later looked up in ``DataTypeTag``
        :param element_d: key later looked up in ``DataTypeTag``
        :param fusion_callbacks: object exposing ``dag_ir`` and ``emit()``
        """
        # The helper returns a pair; only the first element is kept.
        self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(
            tile_description, tile_description.cluster_shape, kernel_schedule)
        self.element_accumulator = element_accumulator

        # Take C's element type from the graph's "C" node when one exists;
        # otherwise mark it as void.
        dag_ir = fusion_callbacks.dag_ir
        if dag_ir.has_node("C"):
            self.element_c = dag_ir.get_node_meta("C").element
        else:
            self.element_c = DataType.void

        self.element_d = element_d
        self.schedule = epilogue_schedule
        self.fusion_callbacks = fusion_callbacks
        self.opclass = tile_description.math_instruction.opcode_class

    @property
    def CtaTileMNK(self) -> str:
        """
        The threadblock shape, rendered as a ``cute::Shape`` of three
        underscore-prefixed extents taken from ``self.cta_tile_mnk``
        """
        tile_m, tile_n, tile_k = (self.cta_tile_mnk[i] for i in range(3))
        return f"cute::Shape<_{tile_m}, _{tile_n}, _{tile_k}>"

    @property
    def EpilogueTileType(self) -> str:
        """
        The epilogue tile type (always ``EpilogueTileAuto``)
        """
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """
        The ``EpilogueScheduleTag`` entry for the configured epilogue schedule
        """
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """
        Render the epilogue descriptor and fusion callbacks.

        :return: a 2-tuple ``(callback_name, source)`` where ``callback_name`` is
            the name returned by ``fusion_callbacks.emit()`` and ``source`` is
            the text of the ``EpilogueDescriptor`` alias followed by the
            callback declaration
        """
        dag_ir = self.fusion_callbacks.dag_ir
        stride_D_str = dag_ir.get_node_meta("D").underlying_impl.stride_mnl
        # When the graph has no "C" node, D's stride fills the C slot.
        if dag_ir.has_node("C"):
            stride_C_str = dag_ir.get_node_meta("C").underlying_impl.stride_mnl
        else:
            stride_C_str = stride_D_str

        callback_decl, callback_name = self.fusion_callbacks.emit()
        return callback_name, f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
  {OpcodeClassTag[self.opclass]},
  {self.CtaTileMNK}, {self.EpilogueTileType},
  {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
  {self.Schedule}, {stride_C_str}, {stride_D_str},
  false /* IsPerColScaleSupported */,
  false /* IsBlockScaleSupported */
>;
{callback_decl}
"""


class Sm100Emitter:
    """
    Entry point that wires a GEMM operation and an epilogue graph into a
    :class:`Sm100CollectiveEpilogue` and delegates emission to it.
    """
    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        :param operation: GEMM operation whose ``tile_description`` supplies the
            schedules and accumulator element type
        :param graph: epilogue graph handed to ``FusionCallbacks``
        """
        callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)

        # Gather the constructor inputs one by one, in signature order.
        tile = operation.tile_description
        kernel_sched = tile.kernel_schedule
        epilogue_sched = tile.epilogue_schedule
        accumulator_type = tile.math_instruction.element_accumulator
        # D's element type is read from the "D" node of the callbacks' graph.
        output_type = callbacks.dag_ir.get_node_meta("D").element

        self.collective_epilogue = Sm100CollectiveEpilogue(
            tile,
            kernel_sched,
            epilogue_sched,
            accumulator_type,
            output_type,
            callbacks,
        )

    def emit(self):
        """
        Return whatever the wrapped collective epilogue's ``emit()`` returns.
        """
        epilogue = self.collective_epilogue
        return epilogue.emit()
